import os

from agents.sub_agent import SubAgent
from tools.arguments import parse
from tools.gridworld_evaluators import evaluate_gridworld

args = parse()

# Size every native thread pool from the target device: a single thread on CPU,
# four otherwise. This runs before numpy/torch are imported further down because
# the OpenMP/BLAS backends read these variables when they are first loaded.
num_threads = 4 if args.device != "cpu" else 1
str_threads = str(num_threads)
os.environ.update(dict.fromkeys(
    (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ),
    str_threads,
))

import gym
import numpy as np
from agents.interfaces import Learner, LearnerAlone
from agents.coord_agent import CoordAgent
from algos.bandit import BanditObs, DijkstraBandit
from algos.l_info import DistopInfo
from algos.networks import MLP, Actor_SAC, SkewfitConvNet
from algos.sac import ContSAC
from replay.buffers import GngManager
from tools.evaluators import evaluate_robotic, multiworld_images, evaluate_maze, eval_maze

import datetime
import torch
# Give torch's intra-op pool and then its inter-op pool the same size as the
# BLAS/OpenMP pools configured at the top of the file.
for _set_torch_threads in (torch.set_num_threads, torch.set_num_interop_threads):
    _set_torch_threads(num_threads)
from torch import nn
from tools import utils

from tools.context import Context
from tools.envs import make_dispatch
from tools.utils import dim_action_space

import os

# Extra third-party packages presumably required to run this project:
# randomdict, networkx, multiworld, pygame, gym 0.15.7
def _build_sac(args, envs, goal_space, obs_input_size, context):
    """Build the goal-conditioned SAC learner used by the low-level agent.

    The critic(s) take ``observation + goal + action`` as input and the actor
    takes ``observation + goal``. The goal part is only included when
    ``args.type == 0``. A second critic is built when ``args.double_critic``
    is set.

    Networks are created in the order critic, critic2, actor. This matches the
    original script, so seeded weight initialisation is unchanged.

    Returns:
        The ``ContSAC`` instance wrapping the actor and critic(s).
    """
    action_dim = dim_action_space(envs.action_space)
    goal_input_size = dim_action_space(goal_space) if args.type == 0 else 0
    policy_input_size = obs_input_size + goal_input_size

    def make_critic():
        return MLP(num_inputs=policy_input_size + action_dim, num_output=1,
                   activation=nn.ELU,
                   num_layers=args.layer_sac, hidden_size=args.neurons_sac).to(args.device)

    critic = make_critic()
    critic2 = make_critic() if args.double_critic else None
    actor_mlp = MLP(num_inputs=policy_input_size, num_output=args.neurons_sac,
                    hidden_size=args.neurons_sac, activation=nn.ELU,
                    num_layers=args.layer_sac, last_linear=False)
    actor = Actor_SAC(actor_mlp, args.neurons_sac, action_dim, envs.action_space, args).to(args.device)
    return ContSAC(args, envs.action_space, critic, actor, context, critic2=critic2)


def _discriminator_offset(args):
    """Return how many leading observation dims the representation learner skips.

    In maze environments only the top view is given to the discriminator, so
    the first 6 (Point mazes) or 29 (Ant mazes) dimensions are dropped.
    Everywhere else the full observation is used.
    """
    if args.env_type == "maze":
        if "Point" in args.env_name:
            return 6
        if "Ant" in args.env_name:
            return 29
    return 0


def _evaluate_epoch(args, i, total_epochs, save_dir, sublearner, coord_learner,
                    coord_policy, goal_space, rollouts, context):
    """Run the environment-specific evaluations and logging after epoch ``i``."""
    epoch_dir = save_dir + "/tmp" + str(i)

    if args.env_type == "gridworld":
        # Evaluate over all states only after the last epoch.
        evaluate_gridworld(args.env_type, args.env_name, sublearner, epoch_dir, goal_space,
                           args=args, predictor=context.estimator,
                           coordpolicy=coord_policy, all_states=(i == total_epochs - 1),
                           context=context)
    elif args.env_type == "multiworld":
        evaluate_robotic(args.env_name, args, sublearner.mpolicy.algo, context, i=i)
        # Images and reward logs on epochs 1, 11, 21, ...
        if i % 10 == 1:
            multiworld_images(sublearner, epoch_dir, args, coord_policy, goal_space, video=0)
            rollouts.oegn.log_rewards(epoch_dir, i)
    elif args.env_type == "maze":
        if not args.state and i == total_epochs // 2:
            multiworld_images(sublearner, epoch_dir, args, coord_policy, goal_space, video=0)
        # Detailed evaluation on roughly 10 evenly spaced epochs. The max()
        # guards against a zero divisor if total_epochs is ever below 10.
        if i % max(1, total_epochs // 10) == 0:
            os.makedirs(epoch_dir, exist_ok=True)
            evaluate_maze(args, sublearner.mpolicy.algo, context, epoch_dir)
            rollouts.oegn.log_rewards(epoch_dir, i)
        eval_maze(sublearner, coord_learner, save_dir, number_steps=5 * args.max_steps, args=args)
    elif args.env_type == "mujoco":
        os.makedirs(save_dir + "/tmp/", exist_ok=True)
        rollouts.oegn.log_rewards(save_dir + "/tmp", i)


def main():
    """Build the environment, agents and learners, then train and evaluate.

    Configuration comes from the module-level ``args``. All outputs (loggers,
    models, evaluation artefacts) go under a run directory named
    ``<args.save><timestamp>_<seed>-<exp_name>``.
    """
    save_dir = (args.save + datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
                + "_" + str(args.seed) + "-" + args.exp_name)

    # Directories and loggers for saving; prepare CUDA or CPU.
    logger, coord_logger, dir_models, eval_logger = utils.create_directories(save_dir, args.seed, args)
    gym.logger.set_level(40)
    utils.prepare_cuda(args)
    context = Context(logger, coord_logger, save_dir=save_dir, path_models=dir_models, args=args)
    context.eval_logger = eval_logger

    # Environment. reset=False is passed only for the sawyer_door env.
    envs = make_dispatch(args.env_type, args, save_dir, video=False,
                         reset=args.env_name != "sawyer_door")
    context.action_size = dim_action_space(envs.action_space)

    # Observation space of the interacting (low-level) agent. For Dict spaces
    # the "observation" entry is used.
    if isinstance(envs.observation_space, gym.spaces.Dict):
        sub_observation_space = envs.observation_space["observation"]
    else:
        sub_observation_space = envs.observation_space
    sub_obs_input_size = sub_observation_space.shape[0]

    # Goal space of the coordinator: a [-1, 1]^num_latents box.
    # NOTE(review): originally annotated "not used". Presumably only its
    # dimensionality matters downstream; confirm before relying on the bounds.
    goal_space = gym.spaces.Box(np.array([-1] * args.num_latents), np.array([1] * args.num_latents),
                                dtype=np.float32)
    goal_size = dim_action_space(goal_space)

    # The DRL agent receives the embedded state when embed_sac or image is set
    # (robotic setting), otherwise the raw observation.
    obs_input_size = args.num_latents if args.embed_sac or args.image else sub_obs_input_size

    # OEGN and replay buffer.
    rollouts = GngManager(args, sub_observation_space.shape, envs.action_space, goal_space,
                          context=context, save_dir=save_dir)

    # Goal-conditioned low-level policy.
    sac = _build_sac(args, envs, goal_space, obs_input_size, context)
    model_subpolicy = SubAgent(args, envs, sac, context, rollouts, gamma=args.gamma,
                               log_interval=args.log_interval)

    # Representation learner. In mazes, only the top view is fed to it.
    start_discriminator = _discriminator_offset(args)
    context.start_discriminator = start_discriminator
    input_discriminator = sub_obs_input_size - start_discriminator
    if args.image and args.env_type == "multiworld":
        net_mlp = SkewfitConvNet(goal_size, args=args).to(args.device)
    else:
        net_mlp = MLP(num_inputs=input_discriminator, num_output=goal_size, activation=nn.ELU,
                      num_layers=args.layer_pred, hidden_size=args.neurons_pred,
                      args=args).to(args.device)
    distop_info = DistopInfo(net_mlp, args, context, rollouts)

    # High-level policy: DijkstraBandit when plan_interval != -1, else BanditObs.
    if args.plan_interval != -1:
        coord_policy = DijkstraBandit(rollouts.oegn, args, goal_space, save_dir=save_dir, context=context)
    else:
        coord_policy = BanditObs(rollouts.oegn, args, goal_space, save_dir=save_dir, context=context)
    sublearner = Learner(model_subpolicy, [distop_info], delay=args.delay, nlearn=args.sub_nlearn, args=args)

    # Shared objects published on the context for the other components.
    context.model_estimator = distop_info
    context.estimator = distop_info
    context.rollouts = rollouts
    context.mpolicy = model_subpolicy
    context.actor_critic = model_subpolicy.algo
    context.goal_size = goal_size
    context.obs_size = sub_obs_input_size
    context.coord_actor = coord_policy

    model_coordination = CoordAgent(sublearner, context, args, coord_policy, rollouts,
                                    log_interval=args.log_interval // args.max_steps)
    coord_learner = LearnerAlone(model_coordination, [], nlearn=1)
    coord_learner.load()

    # Training loop with per-epoch evaluation.
    total_epochs = 10 if args.env_type == "gridworld" else 50
    render = False  # set to True to render the training rollouts
    for i in range(total_epochs):
        coord_learner.reset(render=render)
        sublearner.train()
        coord_learner.step(num_updates=args.num_updates * args.delay, epoch=i + 1,
                           total_epochs=total_epochs, render=render)
        _evaluate_epoch(args, i, total_epochs, save_dir, sublearner, coord_learner,
                        coord_policy, goal_space, rollouts, context)

    if context.save_model:
        coord_learner.save()

    # Final images/videos.
    if (not args.state and args.env_type == "multiworld") or args.env_type == "maze":
        multiworld_images(sublearner, save_dir + "/end_videos", args, coord_policy, goal_space,
                          video=args.video)


if __name__ == '__main__':
    main()



